from copy import deepcopy
from pathlib import Path
from typing import Optional
import json
import os

import torch
from hydra.utils import call
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Subset, Dataset
# from torchvision import transforms as transform_lib
# from torchvision.datasets import CIFAR10, CIFAR100



from collections import defaultdict
import numpy as np
from torch.utils.data import Dataset
import torch
# from src.utils.language_utils import word_to_indices, letter_to_vec

from src.data.data_utils import split_subsets_train_val, split_dataset_train_val, add_attrs


class ShakeSpeareDataModule(LightningDataModule):
    """Federated Shakespeare data: per-client train/val loaders plus one test loader.

    ``setup`` builds the train, validation and test ``ShakeSpeare`` datasets
    together with the per-client sample-index maps exposed by the train and
    validation datasets. ``train_dataloader`` / ``val_dataloader`` serve only
    the samples of client ``client_idx``; :meth:`next_client` moves on to the
    next client. ``test_dataloader`` serves the whole test set.
    """

    name = "ShakeSpeare"

    def __init__(
            self,
            val_split: float = 0.1,
            num_workers: int = 16,
            seed: int = 42,
            batch_size: int = 32,
            num_clients: int = 10,
            fair_val: bool = False,
            client_idx: int = 0,
            *args,
            **kwargs,
    ):
        """
        Args:
            val_split: stored as ``self.val_split``. NOTE(review): ``setup`` does
                not pass it to ``ShakeSpeare``, so the split actually applied is
                that class's own default — confirm whether it should be forwarded.
            num_workers: stored as ``self.num_workers``. NOTE(review): it is not
                passed to any DataLoader below, so loading runs in the main
                process (DataLoader's default).
            seed: stored as ``self.seed``; not read anywhere in this class.
            batch_size: batch size used by all three DataLoaders.
            num_clients: initial client count; ``setup`` replaces it with the
                number of clients found in the training data.
            fair_val: stored as ``self.fair_val``; not read anywhere in this class.
            client_idx: index of the client whose samples the train/val loaders serve.
            *args, **kwargs: forwarded to ``LightningDataModule.__init__``.
        """
        super().__init__(*args, **kwargs)

        # Placeholders until ``setup`` builds the real datasets and index maps.
        self.dataset_train: Dataset = ...
        self.dataset_val: Dataset = ...
        self.dataset_test: Dataset = ...
        self.dict_users = {}
        self.Vdict_users = {}

        self.val_split = val_split
        self.num_workers = num_workers
        self.seed = seed
        self.batch_size = batch_size
        self.num_clients = num_clients
        self.fair_val = fair_val
        self.client_idx = client_idx

    def setup(self, stage: Optional[str] = None):
        """Build the train/val/test datasets and the per-client index maps."""

        self.dataset_train = ShakeSpeare(train=True)
        self.dataset_val = ShakeSpeare(train=True,val=True)
        self.dataset_test = ShakeSpeare(train=False)
        self.dict_users = self.dataset_train.get_client_dic()
        self.Vdict_users = self.dataset_val.get_client_Vdic()
        # The client count comes from the data, overriding the constructor value.
        self.num_clients = len(self.dict_users)

    def next_client(self):
        """Advance ``client_idx`` to the next client.

        Raises:
            AssertionError: if there is no next client. ``client_idx`` is left
                unchanged in that case.
        """
        next_idx = self.client_idx + 1
        # Validate before assigning so a failed call keeps the module on a valid
        # client; raised explicitly (same exception type as before) so the check
        # still runs under ``python -O``.
        if next_idx >= self.num_clients:
            raise AssertionError("Client number shouldn't excced seleced number of clients")
        self.client_idx = next_idx

    def train_dataloader(self):
        """Return a shuffled loader over the current client's training samples."""
        print(f"client_idx: {self.client_idx}")
        loader = DataLoader(DatasetSplit(self.dataset_train, self.dict_users[self.client_idx]), batch_size=self.batch_size, shuffle=True)
        return loader

    def val_dataloader(self):
        """Return an unshuffled loader over the current client's validation samples."""
        print(f"client_idx: {self.client_idx} Val_set len {len(self.Vdict_users[self.client_idx])}")
        loader = DataLoader(DatasetSplit(self.dataset_val, self.Vdict_users[self.client_idx]), batch_size=self.batch_size, shuffle=False)
        return loader

    def test_dataloader(self):
        """Return a loader over the full (all-client) test set."""
        loader = DataLoader(self.dataset_test, batch_size=self.batch_size)
        return loader



class ShakeSpeare(Dataset):
    """Shakespeare next-character dataset read from ``<data_dir>/train`` and ``<data_dir>/test``.

    The JSON files are parsed by ``read_data``. Each sample pairs an input
    string ``x`` with a target character ``y``; ``__getitem__`` encodes them
    with ``word_to_indices`` / ``letter_to_vec``.

    With ``train=True`` every client's training samples are split in order:
    the first ``len - int(val_split * len)`` go to the training portion and
    the rest to the validation portion; ``val`` selects which portion this
    instance serves. ``dic_users`` / ``Vdic_users`` map a client number (its
    position in the client list returned by ``read_data``) to the set of
    indices of that client's samples within the training / validation
    portion respectively.

    With ``train=False`` the test samples of all clients are concatenated and
    no per-client maps are built.
    """

    # Used when ``data_dir`` is not given (the location previously hard-coded here).
    DEFAULT_DATA_DIR = "/home/alballns/norah/KDN_o/data/shakespeare"

    def __init__(self, train=True, val_split=0.2, val=False, data_dir=None):
        super(ShakeSpeare, self).__init__()
        print(f"Current working directory : {os.getcwd()}")
        path = self.DEFAULT_DATA_DIR if data_dir is None else data_dir
        train_clients, train_groups, train_data_temp, test_data_temp = read_data(f"{path}/train",
                                                                                 f"{path}/test")
        self.train = train
        self.val = val

        if self.train:
            (train_x, train_y, val_x, val_y,
             self.dic_users, self.Vdic_users) = self._split_train_val(
                train_clients, train_data_temp, val_split)
            self.train_dic_users = {}  # kept for attribute compatibility; never filled here

            if val:
                print(">> val set")
                self.data, self.label = val_x, val_y
            else:
                self.data, self.label = train_x, train_y
        else:
            self.data, self.label = self._concat_clients(train_clients, test_data_temp)
            print(">> test set")

    @staticmethod
    def _client_samples(client_data, client):
        """Return a client's ``(x, y)`` sample lists, checking they have equal length."""
        xs = client_data[client]['x']
        ys = client_data[client]['y']
        if len(xs) != len(ys):
            raise ValueError(f"client {client!r} has {len(xs)} inputs but {len(ys)} targets")
        return xs, ys

    @classmethod
    def _split_train_val(cls, clients, client_data, val_split):
        """Split each client's samples into a leading train part and a trailing val part.

        Returns ``(train_x, train_y, val_x, val_y, train_idx, val_idx)`` where
        ``train_idx[i]`` / ``val_idx[i]`` hold client ``i``'s positions in
        ``train_x`` / ``val_x``.
        """
        train_x, train_y, val_x, val_y = [], [], [], []
        train_idx, val_idx = {}, {}
        for i, client in enumerate(clients):
            xs, ys = cls._client_samples(client_data, client)
            # Clamp so a val_split outside [0, 1] sends everything to one side
            # instead of producing a negative or oversized cut point.
            n_train = min(max(len(xs) - int(val_split * len(xs)), 0), len(xs))
            # Each index set must be an offset into the list its samples are
            # appended to: validation indices are counted from len(val_x), not
            # len(train_x), otherwise DatasetSplit over the validation dataset
            # picks other clients' samples or runs past the end.
            train_idx[i] = set(range(len(train_x), len(train_x) + n_train))
            val_idx[i] = set(range(len(val_x), len(val_x) + len(xs) - n_train))
            train_x.extend(xs[:n_train])
            train_y.extend(ys[:n_train])
            val_x.extend(xs[n_train:])
            val_y.extend(ys[n_train:])
        return train_x, train_y, val_x, val_y, train_idx, val_idx

    @classmethod
    def _concat_clients(cls, clients, client_data):
        """Concatenate all clients' samples, in client order, into flat ``(x, y)`` lists."""
        all_x, all_y = [], []
        for client in clients:
            xs, ys = cls._client_samples(client_data, client)
            all_x.extend(xs)
            all_y.extend(ys)
        return all_x, all_y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(indices, target)``: the input string as a LongTensor of
        character ids and the target character's id as returned by ``letter_to_vec``."""
        sentence, target = self.data[index], self.label[index]
        indices = torch.LongTensor(np.array(word_to_indices(sentence)))
        return indices, letter_to_vec(target)

    def get_client_dic(self):
        """Return the client -> training-sample-index map.

        Raises:
            SystemExit: on a test-split instance, which has no per-client map.
        """
        if not self.train:
            # Same exception ``exit(msg)`` raised, but without relying on the
            # ``site``-provided ``exit`` builtin (absent under ``python -S``),
            # which also closes ``sys.stdin`` before raising.
            raise SystemExit("The test dataset do not have dic_users!")
        return self.dic_users

    def get_client_Vdic(self):
        """Return the client -> index map for the portion this instance serves.

        That is ``Vdic_users`` when ``val`` is set, otherwise ``dic_users``.

        Raises:
            SystemExit: on a test-split instance, which has no per-client map.
        """
        if not self.train:
            raise SystemExit("The test dataset do not have dic_users!")
        return self.Vdic_users if self.val else self.dic_users

def batch_data(data, batch_size, seed):
    '''Yield ``(batched_x, batched_y)`` mini-batches of one client's data after a joint shuffle.

    ``data`` is a dict ``{'x': sequence, 'y': sequence}`` (lists or numpy
    arrays) of equal length. Both sequences are shuffled with the same
    permutation, determined by ``seed``, and then cut into consecutive slices
    of ``batch_size``; the last batch may be shorter.

    The caller's sequences are not modified, and the global NumPy RNG is
    neither reseeded nor advanced: shuffling happens on copies with a private
    ``RandomState``. For a given ``seed`` the permutation is the same one that
    ``np.random.seed(seed); np.random.shuffle(...)`` produces.
    '''
    # Local import: only this function needs a shallow copy.
    from copy import copy

    # copy() gives lists a new list and ndarrays a new data buffer, so the
    # in-place shuffles below never touch the caller's objects.
    data_x = copy(data['x'])
    data_y = copy(data['y'])

    rng = np.random.RandomState(seed)
    # Replay the same RNG state for both shuffles so x and y stay aligned
    # (also when seed is None and the state is drawn from OS entropy).
    rng_state = rng.get_state()
    rng.shuffle(data_x)
    rng.set_state(rng_state)
    rng.shuffle(data_y)

    for i in range(0, len(data_x), batch_size):
        yield data_x[i:i + batch_size], data_y[i:i + batch_size]

def read_dir(data_dir):
    """Load and merge every ``.json`` file directly inside ``data_dir``.

    Each file must be a JSON object with a ``'user_data'`` mapping and may
    carry a ``'hierarchies'`` list. Client ids are taken from the merged
    ``user_data`` keys.

    Returns:
        clients: sorted list of all keys in the merged ``user_data``.
        groups: the files' ``'hierarchies'`` lists concatenated in file-name
            order (empty if no file has one).
        data: the merged ``user_data`` as a ``defaultdict`` that yields
            ``None`` for unknown clients. When a client appears in several
            files, the file whose name sorts last wins.
    """
    groups = []
    data = defaultdict(lambda: None)

    # Sort the names: ``os.listdir`` order is arbitrary, and both the group
    # order (which ``read_data`` compares between train and test) and the
    # winner for a duplicated client depend on the processing order.
    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for name in json_files:
        # JSON text is UTF-8; don't depend on the locale's default encoding.
        with open(os.path.join(data_dir, name), 'r', encoding='utf-8') as inf:
            cdata = json.load(inf)
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    clients = sorted(data.keys())
    return clients, groups, data

def read_data(train_data_dir, test_data_dir):
    """Parse the JSON data in the given train and test directories.

    Assumes:
    - the data in the input directories are ``.json`` files with the keys
      ``'users'`` and ``'user_data'``
    - the train and test directories contain the same users

    Returns:
        clients: list of client ids (as returned by ``read_dir`` for the train directory)
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data

    Raises:
        ValueError: if the train and test directories disagree on the client
            list or on the group list.
    """
    train_clients, train_groups, train_data = read_dir(train_data_dir)
    test_clients, test_groups, test_data = read_dir(test_data_dir)

    # Explicit checks rather than ``assert`` so they still run under ``python -O``.
    if train_clients != test_clients:
        raise ValueError(
            f"clients in {train_data_dir!r} differ from clients in {test_data_dir!r}")
    if train_groups != test_groups:
        raise ValueError(
            f"groups in {train_data_dir!r} differ from groups in {test_data_dir!r}")

    return train_clients, train_groups, train_data, test_data


class DatasetSplit(Dataset):
    """Index-remapping view over another dataset.

    Position ``k`` of the view is position ``idxs[k]`` of the wrapped
    dataset. ``idxs`` is materialised into a list once, at construction, so
    its iteration order at that moment fixes the view's order. Items of the
    wrapped dataset must unpack into exactly two values.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [*idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        source_position = self.idxs[item]
        sample, target = self.dataset[source_position]
        return sample, target


# Character vocabulary: a character's id is its position in this string.
ALL_LETTERS = (
    "\n !\"&'(),-.0123456789:;>?"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ[]"
    "abcdefghijklmnopqrstuvwxyz}"
)
# Number of symbols in the vocabulary above.
NUM_LETTERS = len(ALL_LETTERS)


def _one_hot(index, size):
    '''returns one-hot vector with given size and value 1 at given index
    '''
    vec = [0 for _ in range(size)]
    vec[int(index)] = 1
    return vec


def letter_to_vec(letter):
    '''Return the class index of ``letter``: its position in ``ALL_LETTERS``.

    Despite the name, this returns a plain ``int``, not a one-hot vector
    (``_one_hot`` builds those). Follows ``str.find`` semantics: returns -1
    when ``letter`` does not occur in ``ALL_LETTERS``, so callers should treat
    -1 as "unknown" rather than as a valid class.
    '''
    return ALL_LETTERS.find(letter)


def word_to_indices(word):
    """Map every character of ``word`` to its position in ``ALL_LETTERS``.

    Args:
        word: string

    Return:
        list of ints, one per character (same length as ``word``); a
        character absent from ``ALL_LETTERS`` maps to -1 (``str.find`` miss).
    """
    return [ALL_LETTERS.find(ch) for ch in word]

